// Copyright 2008 Google Inc. All rights reserved.
// See LICENSE for details of Apache 2.0 license.

import java.text.CharacterIterator;
import java.text.StringCharacterIterator;
import java.util.List;
import java.util.LinkedList;
import java.io.*;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.InvocationTargetException;

/**
 * Test driver that executes compiled Virgil test programs with their specified
 * inputs and checks them against their specified outputs. This allows the JVM
 * code generated by Aeneas to be tested on the JVM.
 *
 * <p>Each command-line argument that does not start with {@code -} names a test
 * file. The first line of that file must be a spec of the form
 * {@code //@execute RUNS}, where {@code RUNS} is either {@code =RESULT} (a single
 * run with no inputs) or a {@code ;}-separated list of {@code INPUTS=RESULT}
 * entries. {@code INPUTS} is one value or a parenthesized, comma-separated list
 * (possibly empty); {@code RESULT} is a value or {@code !ExceptionName}. Values
 * are decimal integers, character literals, {@code true} or {@code false}. The
 * class {@code V3K_<file name without .v3>} must be loadable and declare a
 * static method named {@code main}, which is invoked once per run.
 */
public class V3S_Tester {

    /** One expected execution: the inputs plus either a result or an exception. */
    private static class Run {
        /** Arguments passed to the entrypoint (may be coerced in place before invocation). */
        final Object[] inputs;
        /** Expected return value; null when an exception is expected. */
        final Object result;
        /** Expected Java exception class; null when a normal return is expected. */
        final Class<?> exception;

        Run(Object[] inputs, Object result, Class<?> exception) {
            this.inputs = inputs;
            this.result = result;
            this.exception = exception;
        }

        /** Renders the expected outcome for failure messages. */
        String expected() {
            if (exception != null) {
                return "!" + exception.getName();
            }
            return String.valueOf(result);
        }

        /** Renders the inputs as {@code (a,b,...)}, or the empty string when there are none. */
        String inputs() {
            StringBuilder sb = new StringBuilder();
            if (inputs.length > 0) sb.append('(');
            for (int i = 0; i < inputs.length; i++) {
                if (i > 0) sb.append(',');
                sb.append(inputs[i]);
            }
            if (inputs.length > 0) sb.append(')');
            return sb.toString();
        }
    }

    /**
     * Runs every test file named on the command line and prints a summary.
     * Arguments starting with {@code -} are treated as options and skipped.
     */
    public static void main(String[] args) throws Exception {
        // Count only real test files so the progress totals are accurate.
        int count = 0;
        for (String fname : args) {
            if (!isOption(fname)) count++;
        }
        ProgressPrinter progress = new ProgressPrinter(count, 1);
        for (String fname : args) {
            runTest(fname, progress);
        }
        progress.report();
    }

    /** Returns true for command-line arguments that are options rather than test files. */
    private static boolean isOption(String arg) {
        // startsWith (unlike charAt(0)) is safe on an empty argument.
        return arg.startsWith("-");
    }

    /**
     * Parses the spec line of {@code fname}, runs the corresponding compiled class,
     * and records pass/fail on {@code progress}. Any failure, including errors, is
     * reported rather than propagated so the remaining files still run.
     */
    private static void runTest(String fname, ProgressPrinter progress) {
        if (isOption(fname)) return;
        progress.begin(fname);
        try {
            File file = new File(fname);
            String spec = readFirstLine(file);
            if (spec == null) {
                throw new Exception("empty test file, expected //@execute spec");
            }
            List<Run> runs = parseSpecLine(new StringCharacterIterator(spec));
            runTests(findVirgilEntrypoint(getJavaClassName(file.getName())), runs);
            progress.pass();
        } catch (Throwable t) {
            progress.fail(t.toString());
        }
    }

    /** Reads the first line of {@code file} as UTF-8, or returns null if the file is empty. */
    private static String readFirstLine(File file) throws IOException {
        BufferedReader reader =
            new BufferedReader(new InputStreamReader(new FileInputStream(file), "UTF-8"));
        try {
            return reader.readLine();
        } finally {
            // Close even if reading fails so the file handle is not leaked.
            reader.close();
        }
    }

    /**
     * Invokes {@code testMethod} once per run and checks the outcome.
     *
     * @throws Exception describing the first run whose result or exception does
     *         not match the expectation
     */
    private static void runTests(Method testMethod, List<Run> runs) throws Exception {
        testMethod.setAccessible(true);
        Class<?>[] paramTypes = testMethod.getParameterTypes();
        for (Run run : runs) {
            Object result;
            try {
                coerceInputs(paramTypes, run.inputs);
                result = testMethod.invoke(null, run.inputs);
            } catch (InvocationTargetException e) {
                // The test program itself threw; it must be exactly the expected class.
                Throwable r = e.getCause();
                if (r.getClass() != run.exception) {
                    throw new Exception(run.inputs() + "=!" + r.getClass().getName() + " (" + r.getMessage() + "), expected: " + run.expected());
                }
                continue;
            }
            // A normal return is a failure if an exception was expected or the value differs.
            if (run.exception != null || !compare(run.result, result)) {
                throw new Exception(run.inputs() + "=" + result + ", expected: " + run.expected());
            }
        }
    }

    /**
     * Narrows integral inputs in place to the entrypoint's {@code byte}, {@code short}
     * and {@code char} parameter types. Extra or missing inputs are left for
     * {@link Method#invoke} to reject with an arity error.
     */
    private static void coerceInputs(Class<?>[] types, Object[] inputs) {
        int n = Math.min(types.length, inputs.length);
        for (int i = 0; i < n; i++) {
            Class<?> t = types[i];
            if (t == byte.class) inputs[i] = Byte.valueOf((byte) intValue(inputs[i]));
            else if (t == short.class) inputs[i] = Short.valueOf((short) intValue(inputs[i]));
            else if (t == char.class) inputs[i] = Character.valueOf((char) intValue(inputs[i]));
        }
    }

    /**
     * Compares an expected value with an actual one. Booleans compare by equality.
     * Integral values compare only on the bits of the narrower type, so e.g. a
     * {@code byte} -1 matches a {@code char} 0xFF.
     */
    private static boolean compare(Object o1, Object o2) {
        if (o1 instanceof Boolean) return o1.equals(o2);
        if (o2 instanceof Boolean) return o2.equals(o1);
        if (isIntegral(o1) && isIntegral(o2)) {
            int mask = maskOf(o1) & maskOf(o2);
            return (intValue(o1) & mask) == (intValue(o2) & mask);
        }
        return o1 == null ? o2 == null : o1.equals(o2);
    }

    /**
     * Returns the int value of a boxed integral (0 for null).
     *
     * @throws Error if {@code o} is not a Byte, Short, Character or Integer
     */
    private static int intValue(Object o) {
        if (o == null) return 0;
        if (o instanceof Byte) return ((Byte) o).byteValue();
        if (o instanceof Short) return ((Short) o).shortValue();
        if (o instanceof Character) return ((Character) o).charValue();
        if (o instanceof Integer) return ((Integer) o).intValue();
        throw new Error("not an integral value: " + o);
    }

    /** Returns true if {@code o} is a boxed Integer, Byte, Character or Short. */
    private static boolean isIntegral(Object o) {
        return (o instanceof Integer) || (o instanceof Byte) || (o instanceof Character) || (o instanceof Short);
    }

    /** Returns the bit mask covering the width of {@code o}'s integral type. */
    private static int maskOf(Object o) {
        if (o instanceof Character || o instanceof Short) return 0xFFFF;
        if (o instanceof Byte) return 0xFF;
        return 0xFFFFFFFF;
    }

    /**
     * Maps a Virgil exception name to the Java exception the generated code throws.
     *
     * @throws Exception if the name is not recognized
     */
    private static Class<?> getJavaException(String name) throws Exception {
        if ("NullCheckException".equals(name)) return NullPointerException.class;
        if ("DivideByZeroException".equals(name)) return ArithmeticException.class;
        if ("BoundsCheckException".equals(name)) return ArrayIndexOutOfBoundsException.class;
        if ("TypeCheckException".equals(name)) return ClassCastException.class;
        if ("UnimplementedException".equals(name)) return AbstractMethodError.class;
        if ("LengthCheckException".equals(name)) return NegativeArraySizeException.class;
        throw new Exception("Unknown exception class " + name);
    }

    /**
     * Parses a {@code //@execute} spec line into its list of runs.
     *
     * @throws Exception with a caret-annotated message on malformed input
     */
    public static List<Run> parseSpecLine(CharacterIterator iter) throws Exception {
        List<Run> list = new LinkedList<Run>();
        skipWhiteSpace(iter);
        expect(iter, "//@execute");
        skipWhiteSpace(iter);

        if (option(iter, '=')) {
            // a single run with no inputs.
            skipWhiteSpace(iter);
            list.add(parseResult(iter, new Object[0]));
        } else {
            // one or more ';'-separated runs.
            do {
                skipWhiteSpace(iter);
                list.add(parseRun(iter));
                skipWhiteSpace(iter);
            } while (option(iter, ';'));
        }

        return list;
    }

    /** Parses {@code INPUTS=RESULT}, where INPUTS is a single value or a parenthesized list. */
    private static Run parseRun(CharacterIterator iter) throws Exception {
        Object[] inputs;
        if (option(iter, '(')) {
            inputs = parseValues(iter);
        } else {
            inputs = new Object[] { parseValue(iter) };
        }
        skipWhiteSpace(iter);
        expect(iter, '=');
        skipWhiteSpace(iter);
        return parseResult(iter, inputs);
    }

    /**
     * Parses a comma-separated value list up to and including the closing
     * {@code ')'}; the opening {@code '('} has already been consumed. {@code ()}
     * yields an empty list.
     */
    private static Object[] parseValues(CharacterIterator iter) throws Exception {
        List<Object> input = new LinkedList<Object>();
        skipWhiteSpace(iter);
        if (option(iter, ')')) return input.toArray();
        // Each iteration consumes input or throws, so this always terminates;
        // running off the end of the line is reported by parseValue/expect.
        while (true) {
            skipWhiteSpace(iter);
            input.add(parseValue(iter));
            skipWhiteSpace(iter);
            if (option(iter, ')')) return input.toArray();
            expect(iter, ',');
        }
    }

    /** Parses a RESULT: either {@code !ExceptionName} or a value. */
    private static Run parseResult(CharacterIterator iter, Object[] input) throws Exception {
        if (option(iter, '!')) {
            return new Run(input, null, parseException(iter));
        }
        return new Run(input, parseValue(iter), null);
    }

    /**
     * Parses one value: a decimal int (Integer), a character literal (Byte), or
     * {@code true}/{@code false} (Boolean).
     */
    public static Object parseValue(CharacterIterator iter) throws Exception {
        switch (iter.current()) {
            case '-':
            case '0':
            case '1':
            case '2':
            case '3':
            case '4':
            case '5':
            case '6':
            case '7':
            case '8':
            case '9': return parseInt(iter);
            case '\'': return parseChar(iter);
            case 't': expect(iter, "true"); return Boolean.TRUE;
            case 'f': expect(iter, "false"); return Boolean.FALSE;
        }
        throw parseError(iter, "invalid value");
    }

    /** Advances past spaces, newlines and tabs. */
    public static void skipWhiteSpace(CharacterIterator i) {
        while (true) {
            char c = i.current();
            if (c != ' ' && c != '\n' && c != '\t') break;
            i.next();
        }
    }

    /**
     * Consumes the character {@code c}.
     *
     * @throws Exception if the current character is not {@code c}
     */
    public static void expect(CharacterIterator i, char c) throws Exception {
        char r = i.current();
        if (r != c) throw parseError(i, "expected character \'" + c + "\'");
        i.next();
    }

    /** Consumes each character of {@code s} in order. */
    public static void expect(CharacterIterator iter, String s) throws Exception {
        for (int i = 0; i < s.length(); i++) expect(iter, s.charAt(i));
    }

    /** Consumes {@code c} if it is the current character; returns whether it did. */
    public static boolean option(CharacterIterator i, char c) {
        char r = i.current();
        if (r == c) {
            i.next();
            return true;
        }
        return false;
    }

    /** Parses an optionally negative decimal integer of at most 10 digits. */
    private static Integer parseInt(CharacterIterator iter) throws Exception {
        StringBuilder buf = new StringBuilder();

        if (option(iter, '-')) buf.append('-');

        for (int cntr = 0; cntr < 10; cntr++) {
            char c = iter.current();
            if (!Character.isDigit(c)) break;
            buf.append(c);
            iter.next();
        }

        try {
            return Integer.parseInt(buf.toString());
        } catch (NumberFormatException e) {
            throw parseError(iter, "invalid integer");
        }
    }

    /**
     * Parses a quoted character literal (with {@code \' \" \\ \b \n \r \t \xHH}
     * escapes) into a Byte; the iterator starts on the opening quote.
     */
    private static Byte parseChar(CharacterIterator iter) throws Exception {
        char next = iter.next();
        if (next == '\\') {
            next = iter.next();
            switch (next) {
                case '\'': next = '\''; break;
                case '\"': next = '\"'; break;
                case '\\': next = '\\'; break;
                case 'b': next = '\b'; break;
                case 'n': next = '\n'; break;
                case 'r': next = '\r'; break;
                case 't': next = '\t'; break;
                case 'x': {
                    int vh = hexVal(iter);
                    int vl = hexVal(iter);
                    iter.next(); // step past the second hex digit onto the closing quote
                    expect(iter, '\'');
                    return (byte) ((vh << 4) | vl);
                }
                default: throw parseError(iter, "invalid escaped character");
            }
        }
        iter.next();
        expect(iter, '\'');
        return (byte) next;
    }

    /** Advances one character and returns its hexadecimal digit value. */
    private static int hexVal(CharacterIterator iter) throws Exception {
        char ch = iter.next();
        if (ch <= '9' && ch >= '0') return ch - '0';
        if (ch <= 'f' && ch >= 'a') return 10 + ch - 'a';
        if (ch <= 'F' && ch >= 'A') return 10 + ch - 'A';
        throw parseError(iter, "invalid hex character");
    }

    /** Parses a run of letters as a Virgil exception name and maps it to a Java class. */
    private static Class<?> parseException(CharacterIterator iter) throws Exception {
        StringBuilder buf = new StringBuilder();
        while (iter.current() != CharacterIterator.DONE) {
            char c = iter.current();
            if (!Character.isLetter(c)) break;
            buf.append(c);
            iter.next();
        }
        return getJavaException(buf.toString());
    }

    /**
     * Builds (but does not throw) a parse-error exception whose message shows the
     * whole input with a caret under the current position. The iterator's
     * position is restored before returning.
     */
    private static Exception parseError(CharacterIterator iter, String msg) {
        StringBuilder buffer = new StringBuilder("Parse error: " + msg + "\n");
        int max = iter.getEndIndex();
        int pos = iter.getIndex();
        iter.first();
        for (int i = iter.getBeginIndex(); i < max; i++) {
            buffer.append(iter.current());
            iter.next();
        }
        buffer.append("\n");
        for (int i = iter.getBeginIndex(); i < pos; i++) {
            buffer.append(' ');
        }
        buffer.append("^\n");
        iter.setIndex(pos);
        return new Exception(buffer.toString());
    }

    /**
     * Prints test progress to {@code System.out} using ANSI color escapes.
     * With {@code verbose == 1} it prints one colored mark per test and lists the
     * failures in {@link #report()}. With {@code verbose == 2} it prints one line
     * per test, with failure messages inline. Any other value prints only the
     * final count.
     */
    public static class ProgressPrinter {

        private static final String CTRL_RED = "\u001b[0;31m";
        private static final String CTRL_GREEN = "\u001b[0;32m";
        private static final String CTRL_NORM = "\u001b[0;00m";

        public final int total;
        public final int verbose;
        private String current;
        private int passed;
        private int finished;

        private final PrintStream output = System.out;
        private final List<String> failures = new LinkedList<String>();

        public ProgressPrinter(int total, int verbose) {
            this.total = total;
            this.verbose = verbose;
        }

        /** Marks {@code item} as the test currently running. */
        public void begin(String item) {
            current = item;
            if (verbose == 2) output.print("Running " + item + "...");
        }

        /** Records a pass for the current test. */
        public void pass() {
            passed++;
            if (verbose > 0) output(CTRL_GREEN, 'o', "ok");
        }

        /** Records a failure for the current test with the given message. */
        public void fail(String msg) {
            if (verbose > 0) output(CTRL_RED, 'X', "failed");
            if (verbose == 1) failures.add(CTRL_RED + current + CTRL_NORM + ": " + msg);
            if (verbose == 2) this.output.println(" -> " + msg);
        }

        /** Prints a colored mark (verbose 1, wrapping every 50) or word (verbose 2). */
        private void output(String ctrl, char ch, String str) {
            finished++;
            if (verbose == 1) {
                output.print(ctrl);
                output.print(ch);
                output.print(CTRL_NORM);
                if (finished % 50 == 0 || finished == total) this.output.print(" " + finished + " of " + total + "\n");
                else if (finished % 10 == 0) this.output.print(' ');
            } else if (verbose == 2) {
                output.print(ctrl);
                output.print(str);
                output.print(CTRL_NORM);
                output.println("");
            }
        }

        /** Prints the pass count and, in verbose mode 1, each recorded failure. */
        public void report() {
            output.println(passed + " of " + total + " passed");
            if (verbose == 1) {
                for (String s : failures) {
                    output.println(s);
                }
            }
        }
    }

    /**
     * Maps a test file name to its generated Java class name. It strips a
     * trailing {@code .v3} extension and any leading directory components, then
     * prefixes {@code V3K_}.
     */
    public static String getJavaClassName(String className) {
        // endsWith (not indexOf) so names like "a.v3x.v3" strip only the real extension;
        // a bare ".v3" is left intact rather than producing an empty name.
        if (className.endsWith(".v3") && className.length() > 3) {
            className = className.substring(0, className.length() - 3);
        }
        int sep = className.lastIndexOf(File.separatorChar);
        if (sep >= 0) className = className.substring(sep + 1);
        return "V3K_" + className;
    }

    /**
     * Loads {@code className} and returns its first declared static method named
     * {@code main}.
     *
     * @throws Exception if the class cannot be loaded or has no such method
     */
    public static Method findVirgilEntrypoint(String className) throws Exception {
        Class<?> testClass;
        try {
            testClass = Class.forName(className);
        } catch (ClassNotFoundException e) {
            throw new Exception("could not find class " + className, e);
        }
        for (Method method : testClass.getDeclaredMethods()) {
            if (Modifier.isStatic(method.getModifiers()) && "main".equals(method.getName()))
                return method;
        }
        throw new Exception("could not find main method in class " + className);
    }
}
